import torch
from torch import Tensor

from data_utils import row_normalize, get_normalized_degree, get_contraction_coefficient, get_max_deviations


class MeanAggregation(torch.nn.Module):
    """Impute missing node features by repeated sparse propagation.

    Every feature column is propagated with its own sparse adjacency matrix
    for ``num_iterations`` rounds. Alongside the plain imputation the module
    tracks, for every tolerance in ``epsilons``, a second copy whose group
    mean gap is projected onto a band after each round, and it records
    per-iteration discrimination statistics and a recursive risk bound.
    """

    def __init__(self, num_iterations: int):
        super().__init__()
        self.num_iterations = num_iterations

    @staticmethod
    def _group_gap(values: Tensor, group_mask: Tensor) -> Tensor:
        """Return the per-column absolute difference between the two group means."""
        return torch.abs(
            torch.mean(values[group_mask], dim=0) - torch.mean(values[~group_mask], dim=0)
        )

    def impute(
        self, x: Tensor, edge_index_list, edge_weight_list, feature_transform: Tensor,
        beta, feature_mask: Tensor, group_mask: Tensor, epsilons
    ):
        """Run the propagation and return imputed features plus diagnostics.

        Args:
            x: 2-D feature matrix, indexed as ``x[node, feature]``.
            edge_index_list: Per-feature sparse COO indices; entry ``f`` is
                used to build the ``(n_nodes, n_nodes)`` matrix for column ``f``.
            edge_weight_list: Per-feature sparse COO values matching
                ``edge_index_list``.
            feature_transform: Per-node scaling, applied as
                ``feature_transform.view(-1, 1) * x`` before propagation and
                undone with its reciprocal afterwards.
            beta: Mixing weight; after each propagation, masked entries become
                ``beta * propagated + (1 - beta) * original``. ``beta == 0``
                selects the projection variant that only moves unmasked entries.
            feature_mask: Boolean tensor with the shape of ``x``; ``True`` marks
                entries whose value in ``x`` is used (observed entries).
            group_mask: Boolean tensor over nodes splitting them into two groups.
            epsilons: Re-iterable collection of tolerances. Values are used as
                dictionary keys and are expected to be distinct.

        Returns:
            A tuple ``(out, fair_out, discrimination_risk,
            fair_discrimination_risk, risk_bounds, alphas)`` where ``out`` is
            the imputed matrix, ``fair_out`` maps each epsilon to its projected
            matrix, the two risk entries hold one float per iteration (a list,
            and a dict of lists keyed by epsilon), ``risk_bounds`` holds
            ``num_iterations + 1`` floats and ``alphas`` holds the per-feature
            values returned by ``get_contraction_coefficient``.
        """
        n_nodes = x.shape[0]
        n_features = x.size(1)
        scale = feature_transform.view(-1, 1)
        inv_scale = (1 / feature_transform).view(-1, 1)

        # out is initialized to 0 for missing values. However, its initialization
        # does not matter for the final value at convergence.
        out = torch.zeros_like(x)
        out[feature_mask] = x[feature_mask]
        fair_out = {epsilon: out.clone() for epsilon in epsilons}

        discrimination_risk = []
        fair_discrimination_risk = {epsilon: [] for epsilon in epsilons}

        # Risk bound calculation operates in the transformed (scaled) space.
        transformed_x = scale * x
        transformed_out = scale * out
        fair_transformed_out = {epsilon: scale * fair_out[epsilon] for epsilon in epsilons}

        prev_risk = self._group_gap(transformed_out, group_mask)
        risk_bounds = [torch.max(prev_risk).item()]

        group_size = group_mask.sum()
        other_size = (~group_mask).sum()

        mu = torch.zeros(n_features, device=x.device)
        alphas = torch.zeros(n_features, device=x.device)
        adj_list = []
        for f in range(n_features):
            observed = feature_mask[:, f]
            alphas[f] = get_contraction_coefficient(
                edge_index_list[f], edge_weight_list[f], beta, observed, group_mask, n_nodes
            )
            # R: 0, Q: 1
            column = x[:, f]
            mu[f] = torch.sum(column[observed & ~group_mask]) / other_size \
                - torch.sum(column[observed & group_mask]) / group_size
            # torch.sparse_coo_tensor replaces the deprecated legacy
            # torch.sparse.FloatTensor constructor.
            adj_list.append(
                torch.sparse_coo_tensor(
                    edge_index_list[f], edge_weight_list[f], size=(n_nodes, n_nodes)
                ).to(edge_index_list[f].device)
            )

        lower = {epsilon: mu - epsilon for epsilon in epsilons}
        upper = {epsilon: mu + epsilon for epsilon in epsilons}

        # The projection directions depend only on the masks and on whether
        # beta is zero, so they are built once instead of on every iteration.
        observed_fixed = bool(beta == 0)
        full_direction = torch.zeros(n_nodes, device=x.device)
        full_direction[group_mask] = 1.0 / group_size
        full_direction[~group_mask] = -1.0 / other_size
        projections = []
        for f in range(n_features):
            if observed_fixed:
                # Only entries outside the feature mask are allowed to move.
                rows = ~feature_mask[:, f]
                direction = full_direction[rows]
            else:
                rows = slice(None)
                direction = full_direction
            projections.append((rows, direction, torch.dot(direction, direction)))

        # Loop-invariant part of the mixing step for masked entries.
        anchor = (1 - beta) * transformed_x[feature_mask]

        for _ in range(self.num_iterations):
            # Diffuse current features, one sparse product per feature column.
            for f in range(n_features):
                adjacency = adj_list[f]
                transformed_out[:, f] = torch.sparse.mm(
                    adjacency, transformed_out[:, f].view(-1, 1)
                ).view(-1)
                for epsilon in epsilons:
                    fair_state = fair_transformed_out[epsilon]
                    fair_state[:, f] = torch.sparse.mm(
                        adjacency, fair_state[:, f].view(-1, 1)
                    ).view(-1)

            transformed_out[feature_mask] = beta * transformed_out[feature_mask] + anchor
            deviations = get_max_deviations(transformed_out)
            for epsilon in epsilons:
                fair_state = fair_transformed_out[epsilon]
                fair_state[feature_mask] = beta * fair_state[feature_mask] + anchor

            out = inv_scale * transformed_out
            for epsilon in epsilons:
                fair_out[epsilon] = inv_scale * fair_transformed_out[epsilon]

            # NOTE(review): the reported risk takes the minimum over features
            # while the bound below takes the maximum -- confirm this is intended.
            discrimination_risk.append(torch.min(self._group_gap(out, group_mask)).item())
            prev_risk = alphas * prev_risk + 2 * deviations
            risk_bounds.append(torch.max(prev_risk).item())

            # NOTE(review): the projection below edits fair_out only; it is not
            # written back to fair_transformed_out, so the next diffusion round
            # starts from the unprojected state -- confirm this is intended.
            for f in range(n_features):
                rows, direction, norm_sq = projections[f]
                for epsilon in epsilons:
                    fair_column = fair_out[epsilon][:, f]
                    inner = torch.dot(direction, fair_column[rows])
                    # NOTE(review): when beta != 0 the violation test uses
                    # mu +/- epsilon but the projection target is +/- epsilon;
                    # kept as in the original -- confirm against the method.
                    if inner < lower[epsilon][f]:
                        target = lower[epsilon][f] if observed_fixed else -epsilon
                    elif inner > upper[epsilon][f]:
                        target = upper[epsilon][f] if observed_fixed else epsilon
                    else:
                        continue
                    # Move along the direction just far enough to hit the target.
                    fair_column[rows] -= ((inner - target) / norm_sq) * direction

            for epsilon in epsilons:
                fair_discrimination_risk[epsilon].append(
                    torch.min(self._group_gap(fair_out[epsilon], group_mask)).item()
                )

        return out, fair_out, discrimination_risk, fair_discrimination_risk, risk_bounds, alphas
